{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import logging\n",
    "from box import Box\n",
    "from datetime import datetime\n",
    "import sys\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !git clone https://github.com/facebookresearch/DrQA.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fast_bert import BertLearner, BertDataBunch, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.set_option('display.max_colwidth', -1)\n",
    "run_start_time = datetime.today().strftime('%Y-%m-%d_%H-%M-%S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input locations: the data and label directories of the sample dataset.\n",
    "PATH = Path(\"../sample_data/imdb_movie_reviews\")\n",
    "DATA_PATH = PATH / 'data'\n",
    "LABEL_PATH = PATH / 'label'\n",
    "\n",
    "# Output locations, all nested under a hidden .output directory.\n",
    "OUT_PATH = PATH / '.output'\n",
    "MODEL_PATH = OUT_PATH / 'model'\n",
    "LOG_PATH = OUT_PATH / 'logs/'\n",
    "\n",
    "# Create the output root first, then its sub-directories; existing ones are left alone.\n",
    "for out_dir in (OUT_PATH, MODEL_PATH, LOG_PATH):\n",
    "    out_dir.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: label used in the log file name, tokenisation/batching limits,\n",
    "# optimisation settings and the pretrained model to fine-tune.\n",
    "# NOTE(review): 5e-3 is well above the learning rates usually used when fine-tuning\n",
    "# transformer models (roughly 1e-5 to 1e-4) - confirm this value is intentional.\n",
    "args = Box({\n",
    "    \"run_text\": \"imdb_reviews\",\n",
    "    \"max_seq_length\": 512,\n",
    "    \"batch_size\": 8,\n",
    "    \"learning_rate\": 5e-3,\n",
    "    \"num_train_epochs\": 6,\n",
    "    \"fp16\": False,\n",
    "    \"model_name\": 'distilroberta-base',\n",
    "    \"model_type\": 'roberta'\n",
    "})\n",
    "\n",
    "# Query the GPU count once: use CUDA when at least one GPU is visible, else the CPU.\n",
    "gpu_count = torch.cuda.device_count()\n",
    "device = torch.device('cuda') if gpu_count else torch.device('cpu')\n",
    "\n",
    "# Multi-GPU mode is only enabled when more than one GPU is visible.\n",
    "args.multi_gpu = gpu_count > 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One log file per run, named after the start timestamp and the run label.\n",
    "log_name = 'log-{}-{}.txt'.format(run_start_time, args[\"run_text\"])\n",
    "logfile = str(LOG_PATH/log_name)\n",
    "\n",
    "# Records go both to the log file and to stdout (the notebook output area).\n",
    "log_handlers = [logging.FileHandler(logfile), logging.StreamHandler(sys.stdout)]\n",
    "\n",
    "logging.basicConfig(\n",
    "    handlers=log_handlers,\n",
    "    level=logging.INFO,\n",
    "    datefmt='%m/%d/%Y %H:%M:%S',\n",
    "    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',\n",
    ")\n",
    "\n",
    "# Root logger.\n",
    "logger = logging.getLogger()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12/15/2019 01:52:02 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-vocab.json from cache at /Users/kaushaltrivedi/.cache/torch/transformers/5f11352d3c3e932888f3ba75bc24579eacb5d1596d39ce56166aeae8fd363df8.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b\n",
      "12/15/2019 01:52:02 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-merges.txt from cache at /Users/kaushaltrivedi/.cache/torch/transformers/01f63a14ad93494c050af2090c59930fb787bdfb347c4cad7ce9063e1a5fe140.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n",
      "12/15/2019 01:52:02 - INFO - root -   Loading features from cached file ../sample_data/imdb_movie_reviews/data/cache/cached_roberta_train_multi_class_512_train_sample.csv\n",
      "12/15/2019 01:52:02 - INFO - root -   Loading features from cached file ../sample_data/imdb_movie_reviews/data/cache/cached_roberta_dev_multi_class_512_val_sample.csv\n"
     ]
    }
   ],
   "source": [
    "# Keyword settings for the data bunch, gathered in one place before the call.\n",
    "databunch_options = dict(\n",
    "    train_file=\"train_sample.csv\",\n",
    "    val_file=\"val_sample.csv\",\n",
    "    batch_size_per_gpu=args.batch_size,\n",
    "    max_seq_length=args.max_seq_length,\n",
    "    multi_gpu=args.multi_gpu,\n",
    "    multi_label=False,\n",
    "    model_type=args.model_type,\n",
    ")\n",
    "\n",
    "# Positional arguments: data directory, label directory, pretrained model name.\n",
    "databunch = BertDataBunch(DATA_PATH, LABEL_PATH, args.model_name, **databunch_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics handed to the learner: a single entry pairing a display name with its function.\n",
    "metrics = [dict(name=\"accuracy\", function=accuracy)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12/15/2019 01:52:03 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json from cache at /Users/kaushaltrivedi/.cache/torch/transformers/d52ced8fd31ba6aa311b6eeeae65178cca00ddd6333c087be4601dc46c20bd96.0cc9c825897545d1d8c78f2343db2a450d3531eb9b0fb77a96cc487ebbb74210\n",
      "12/15/2019 01:52:03 - INFO - transformers.configuration_utils -   Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"finetuning_task\": null,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"is_decoder\": false,\n",
      "  \"layer_norm_eps\": 1e-05,\n",
      "  \"max_position_embeddings\": 514,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 6,\n",
      "  \"num_labels\": 2,\n",
      "  \"output_attentions\": false,\n",
      "  \"output_hidden_states\": false,\n",
      "  \"output_past\": true,\n",
      "  \"pruned_heads\": {},\n",
      "  \"torchscript\": false,\n",
      "  \"type_vocab_size\": 1,\n",
      "  \"use_bfloat16\": false,\n",
      "  \"vocab_size\": 50265\n",
      "}\n",
      "\n",
      "12/15/2019 01:52:05 - INFO - transformers.modeling_utils -   loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-pytorch_model.bin from cache at /Users/kaushaltrivedi/.cache/torch/transformers/98d2960e6de3e0a3a34b8566dff60ab2fa946e85134dd5e4722c54ff58044a96.52b6ec356fb91985b3940e086d1b2ebf8cd40f8df0ba1cabf4cac27769dee241\n",
      "12/15/2019 01:52:08 - INFO - transformers.modeling_utils -   Weights of RobertaForSequenceClassification not initialized from pretrained model: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
      "12/15/2019 01:52:08 - INFO - transformers.modeling_utils -   Weights from pretrained model not used in RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']\n"
     ]
    }
   ],
   "source": [
    "# Keyword settings for the learner, gathered in one place before the call.\n",
    "learner_options = dict(\n",
    "    metrics=metrics,\n",
    "    device=device,\n",
    "    multi_gpu=args.multi_gpu,\n",
    "    is_fp16=args.fp16,\n",
    "    multi_label=False,\n",
    "    logging_steps=0,\n",
    "    output_dir=OUT_PATH,\n",
    "    logger=logger,\n",
    ")\n",
    "\n",
    "# Build the learner from the data bunch and the pretrained model name.\n",
    "learner = BertLearner.from_pretrained_model(databunch, args.model_name, **learner_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12/15/2019 01:52:08 - INFO - root -   Running evaluation\n",
      "12/15/2019 01:52:08 - INFO - root -     Num examples = 50\n",
      "12/15/2019 01:52:08 - INFO - root -     Batch size = 16\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='4' class='' max='4', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [4/4 00:32<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'loss': 0.6959984004497528, 'accuracy': 0.38}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner.validate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../sample_data/imdb_movie_reviews/.output/tensorboard\n",
      "12/15/2019 01:52:41 - INFO - root -   ***** Running training *****\n",
      "12/15/2019 01:52:41 - INFO - root -     Num examples = 100\n",
      "12/15/2019 01:52:41 - INFO - root -     Num Epochs = 2\n",
      "12/15/2019 01:52:41 - INFO - root -     Total train batch size (w. parallel, distributed & accumulation) = 8\n",
      "12/15/2019 01:52:41 - INFO - root -     Gradient Accumulation steps = 1\n",
      "12/15/2019 01:52:41 - INFO - root -     Total optimization steps = 26\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12/15/2019 01:56:18 - INFO - root -   Running evaluation\n",
      "12/15/2019 01:56:18 - INFO - root -     Num examples = 50\n",
      "12/15/2019 01:56:18 - INFO - root -     Batch size = 16\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='4' class='' max='4', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [4/4 00:38<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12/15/2019 01:56:57 - INFO - root -   eval_loss after epoch 1: 0.47291456162929535: \n",
      "12/15/2019 01:56:57 - INFO - root -   eval_accuracy after epoch 1: 0.82: \n",
      "12/15/2019 01:56:57 - INFO - root -   lr after epoch 1: 0.0025\n",
      "12/15/2019 01:56:57 - INFO - root -   train_loss after epoch 1: 0.6882386070031387\n",
      "12/15/2019 01:56:57 - INFO - root -   \n",
      "\n",
      "12/15/2019 02:01:45 - INFO - root -   Running evaluation\n",
      "12/15/2019 02:01:45 - INFO - root -     Num examples = 50\n",
      "12/15/2019 02:01:45 - INFO - root -     Batch size = 16\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='4' class='' max='4', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [4/4 00:35<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12/15/2019 02:02:21 - INFO - root -   eval_loss after epoch 2: 0.3704464165493846: \n",
      "12/15/2019 02:02:21 - INFO - root -   eval_accuracy after epoch 2: 0.84: \n",
      "12/15/2019 02:02:21 - INFO - root -   lr after epoch 2: 0.0\n",
      "12/15/2019 02:02:21 - INFO - root -   train_loss after epoch 2: 0.23913450768360725\n",
      "12/15/2019 02:02:21 - INFO - root -   \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(26, 0.46368655734337294)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train for the configured number of epochs rather than a hard-coded count, so that\n",
    "# args.num_train_epochs actually controls the run; validation stays enabled.\n",
    "learner.fit(args.num_train_epochs, args.learning_rate, validate=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12/15/2019 02:02:22 - INFO - transformers.configuration_utils -   Configuration saved in ../sample_data/imdb_movie_reviews/.output/model_out/config.json\n",
      "12/15/2019 02:02:22 - INFO - transformers.modeling_utils -   Model weights saved in ../sample_data/imdb_movie_reviews/.output/model_out/pytorch_model.bin\n"
     ]
    }
   ],
   "source": [
    "learner.save_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fast_bert.prediction import BertClassificationPredictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12/15/2019 02:02:22 - INFO - transformers.tokenization_utils -   Model name '../sample_data/imdb_movie_reviews/.output/model_out' not found in model shortcut name list (roberta-base, roberta-large, roberta-large-mnli, distilroberta-base, roberta-base-openai-detector, roberta-large-openai-detector). Assuming '../sample_data/imdb_movie_reviews/.output/model_out' is a path or url to a directory containing tokenizer files.\n",
      "12/15/2019 02:02:22 - INFO - transformers.tokenization_utils -   loading file ../sample_data/imdb_movie_reviews/.output/model_out/vocab.json\n",
      "12/15/2019 02:02:22 - INFO - transformers.tokenization_utils -   loading file ../sample_data/imdb_movie_reviews/.output/model_out/merges.txt\n",
      "12/15/2019 02:02:22 - INFO - transformers.tokenization_utils -   loading file ../sample_data/imdb_movie_reviews/.output/model_out/added_tokens.json\n",
      "12/15/2019 02:02:22 - INFO - transformers.tokenization_utils -   loading file ../sample_data/imdb_movie_reviews/.output/model_out/special_tokens_map.json\n",
      "12/15/2019 02:02:22 - INFO - transformers.tokenization_utils -   loading file ../sample_data/imdb_movie_reviews/.output/model_out/tokenizer_config.json\n",
      "12/15/2019 02:02:22 - INFO - transformers.configuration_utils -   loading configuration file ../sample_data/imdb_movie_reviews/.output/model_out/config.json\n",
      "12/15/2019 02:02:22 - INFO - transformers.configuration_utils -   Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"finetuning_task\": null,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"is_decoder\": false,\n",
      "  \"layer_norm_eps\": 1e-05,\n",
      "  \"max_position_embeddings\": 514,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 6,\n",
      "  \"num_labels\": 2,\n",
      "  \"output_attentions\": false,\n",
      "  \"output_hidden_states\": false,\n",
      "  \"output_past\": true,\n",
      "  \"pruned_heads\": {},\n",
      "  \"torchscript\": false,\n",
      "  \"type_vocab_size\": 1,\n",
      "  \"use_bfloat16\": false,\n",
      "  \"vocab_size\": 50265\n",
      "}\n",
      "\n",
      "12/15/2019 02:02:22 - INFO - transformers.modeling_utils -   loading weights file ../sample_data/imdb_movie_reviews/.output/model_out/pytorch_model.bin\n"
     ]
    }
   ],
   "source": [
    "# Directory under the output folder that the predictor loads the model from.\n",
    "trained_model_dir = OUT_PATH/'model_out'\n",
    "\n",
    "predictor = BertClassificationPredictor(\n",
    "    trained_model_dir,\n",
    "    LABEL_PATH,\n",
    "    multi_label=False,\n",
    "    model_type=args.model_type,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12/15/2019 02:03:05 - INFO - root -   Writing example 0 of 1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[[('0', 0.934996485710144), ('1', 0.06500352919101715)]]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor.predict_batch([\"i hate you\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
